"""
Test Data Consistency Analysis Across Runs and Iterations
=========================================================

This script analyzes the test data files under the 'results' folder in SageMaker to
determine whether they are all identical across runs and iterations.

Expected folder structure:
results/
├── run_1/
│   ├── iteration_1/
│   │   └── test_data.csv (or .json, .pkl, .pickle, .parquet)
│   ├── iteration_2/
│   │   └── test_data.csv
│   └── ...
├── run_2/
│   ├── iteration_1/
│   │   └── test_data.csv
│   └── ...
└── ...

Usage for SageMaker notebook:
```python
from analyze_test_data_consistency import TestDataConsistencyAnalyzer

# Initialize analyzer
analyzer = TestDataConsistencyAnalyzer('results')

# Run analysis
results = analyzer.analyze_consistency()

# Display results
analyzer.print_summary(results)
```
"""

import os
import pandas as pd
import numpy as np
import json
import pickle
import hashlib
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional
import warnings

# Install a blanket "ignore" filter at import time: no category or module is
# given, so every warning is silenced globally, not only those raised here.
warnings.filterwarnings('ignore')

class TestDataConsistencyAnalyzer:
    """Analyzes consistency of test data across multiple runs and iterations.

    The analyzer walks ``<base_path>/run*/iter*/``, picks one file per
    iteration whose name contains ``test_data`` and whose suffix is in
    ``supported_formats``, loads it into a DataFrame, and groups the files by
    a content hash. When more than one group exists, every file is also
    compared cell-by-cell against a reference file.
    """

    # Sentinel returned by compute_data_hash() when a frame cannot be hashed.
    # Files carrying this value are never treated as proven-identical.
    HASH_ERROR = "error"

    def __init__(self, base_path: str = 'results'):
        """
        Initialize the analyzer

        Args:
            base_path: Path to the results directory
        """
        self.base_path = Path(base_path)
        self.supported_formats = ['.csv', '.json', '.pkl', '.pickle', '.parquet']

    def find_test_data_files(self) -> Dict[str, Dict[str, str]]:
        """
        Find all test_data files across runs and iterations

        Directory entries are visited in sorted order so that the file chosen
        per iteration (the first match) and the overall ordering of the result
        are reproducible between invocations.

        Returns:
            Dictionary with structure: {run_id: {iteration_id: file_path}}.
            A run directory without any matching file maps to an empty dict.
        """
        files: Dict[str, Dict[str, str]] = {}

        if not self.base_path.exists():
            print(f"❌ Base path '{self.base_path}' does not exist")
            return files
        if not self.base_path.is_dir():
            print(f"❌ Base path '{self.base_path}' is not a directory")
            return files

        for run_dir in sorted(self.base_path.iterdir()):
            if not (run_dir.is_dir() and run_dir.name.startswith('run')):
                continue
            run_files = files.setdefault(run_dir.name, {})

            for iter_dir in sorted(run_dir.iterdir()):
                if not (iter_dir.is_dir() and iter_dir.name.startswith('iter')):
                    continue

                # Only the first matching file of an iteration is recorded.
                for file_path in sorted(iter_dir.iterdir()):
                    if (file_path.is_file()
                            and 'test_data' in file_path.name.lower()
                            and file_path.suffix.lower() in self.supported_formats):
                        run_files[iter_dir.name] = str(file_path)
                        break

        return files

    def load_test_data(self, file_path: str) -> Optional[pd.DataFrame]:
        """
        Load test data from various formats

        Args:
            file_path: Path to the test data file

        Returns:
            Loaded data as DataFrame or None if loading fails (the failure is
            printed, not raised)
        """
        try:
            file_path = Path(file_path)
            suffix = file_path.suffix.lower()

            if suffix == '.csv':
                return pd.read_csv(file_path)
            elif suffix == '.json':
                with open(file_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                # A top-level list becomes one row per element; anything else
                # is flattened with json_normalize.
                return pd.DataFrame(data) if isinstance(data, list) else pd.json_normalize(data)
            elif suffix in ['.pkl', '.pickle']:
                # SECURITY NOTE: unpickling executes arbitrary code embedded in
                # the file. Only point the analyzer at result folders you trust.
                with open(file_path, 'rb') as f:
                    data = pickle.load(f)
                return data if isinstance(data, pd.DataFrame) else pd.DataFrame(data)
            elif suffix == '.parquet':
                return pd.read_parquet(file_path)
            else:
                print(f"⚠️ Unsupported file format: {suffix}")
                return None

        except Exception as e:
            # Best-effort loader: any reader failure is reported and mapped to None.
            print(f"❌ Error loading {file_path}: {e}")
            return None

    def compute_data_hash(self, df: pd.DataFrame) -> str:
        """
        Compute a hash of the DataFrame for comparison

        Rows are sorted by all columns first, so the hash does not depend on
        row order. Floats are rendered with 17 significant digits instead of
        the default display precision, so values that differ only beyond the
        sixth digit still produce different hashes.

        Args:
            df: DataFrame to hash

        Returns:
            MD5 hex digest of the rendered content, or ``HASH_ERROR`` if the
            frame could not be sorted/rendered (the failure is printed)
        """
        try:
            if len(df.columns):
                ordered = df.sort_values(by=list(df.columns)).reset_index(drop=True)
            else:
                # Nothing to sort by for a frame without columns.
                ordered = df

            content = ordered.to_string(
                index=False,
                float_format=lambda value: format(value, '.17g'),
            )
            # MD5 is used as a content fingerprint only, not for security.
            return hashlib.md5(content.encode()).hexdigest()
        except Exception as e:
            print(f"⚠️ Error computing hash: {e}")
            return self.HASH_ERROR

    def compare_dataframes(self, df1: pd.DataFrame, df2: pd.DataFrame) -> Dict[str, Any]:
        """
        Compare two DataFrames and return detailed comparison results

        Args:
            df1, df2: DataFrames to compare

        Returns:
            Dictionary with keys 'identical', 'shape_match', 'columns_match',
            'content_match' (booleans) and 'differences' (list of messages)
        """
        comparison = {
            'identical': False,
            'shape_match': False,
            'columns_match': False,
            'content_match': False,
            'differences': []
        }

        # Check shapes
        if df1.shape == df2.shape:
            comparison['shape_match'] = True
        else:
            comparison['differences'].append(f"Shape mismatch: {df1.shape} vs {df2.shape}")

        # Check columns
        if list(df1.columns) == list(df2.columns):
            comparison['columns_match'] = True
        else:
            comparison['differences'].append(f"Column mismatch: {list(df1.columns)} vs {list(df2.columns)}")

        # Cell-level comparison only makes sense when the structure matches
        if not (comparison['shape_match'] and comparison['columns_match']):
            return comparison

        try:
            if df1.equals(df2):
                comparison['content_match'] = True
                comparison['identical'] = True
                return comparison

            # `!=` flags NaN vs NaN as different, whereas equals() treats
            # missing values in the same position as equal; mask them out so
            # the count agrees with the equals() verdict.
            both_missing = df1.isna() & df2.isna()
            diff_mask = (df1 != df2) & ~both_missing
            diff_count = diff_mask.sum().sum()

            # equals() also fails on dtype differences with equal values.
            dtype_diffs = [
                f"{col} ({left} vs {right})"
                for col, left, right in zip(df1.columns, df1.dtypes, df2.dtypes)
                if left != right
            ]

            if diff_count or not dtype_diffs:
                comparison['differences'].append(f"Content differences: {diff_count} cells differ")
            if dtype_diffs:
                comparison['differences'].append(f"Dtype mismatch: {', '.join(dtype_diffs)}")
        except Exception as e:
            # e.g. frames with different index labels cannot be compared elementwise
            comparison['differences'].append(f"Error comparing content: {e}")

        return comparison

    @staticmethod
    def _first_entry(data_info: Dict[str, Dict[str, Any]]) -> Optional[Tuple[str, str]]:
        """Return the first (run_id, iter_id) pair that has data, or None."""
        for run_id, iterations in data_info.items():
            for iter_id in iterations:
                return run_id, iter_id
        return None

    def analyze_consistency(self) -> Dict[str, Any]:
        """
        Main analysis function to check consistency across all runs and iterations

        Returns:
            Dictionary with analysis results. On failure it contains only an
            'error' key. Otherwise it contains 'total_files', 'total_runs',
            'unique_hashes', 'all_identical', 'data_info', 'hash_groups',
            'failed_files', 'hash_errors' and, when the files are not all
            identical, 'detailed_comparisons'.
        """
        print("🔍 Searching for test data files...")
        files = self.find_test_data_files()

        # Run directories without any test data file yield empty entries, so
        # count the files themselves rather than testing `files` for emptiness.
        total_files = sum(len(iterations) for iterations in files.values())
        if total_files == 0:
            return {'error': 'No test data files found'}

        print(f"📁 Found {total_files} test data files across {len(files)} runs")

        for run_id, iterations in files.items():
            print(f"   {run_id}: {len(iterations)} iterations")

        # Load all data and compute hashes
        print("\n📊 Loading and analyzing data...")
        data_info: Dict[str, Dict[str, Any]] = {}
        failed_files: List[str] = []

        for run_id, iterations in files.items():
            data_info[run_id] = {}

            for iter_id, file_path in iterations.items():
                print(f"   Loading {run_id}/{iter_id}...")
                df = self.load_test_data(file_path)

                if df is None:
                    print(f"   ❌ Failed to load {run_id}/{iter_id}")
                    failed_files.append(f"{run_id}/{iter_id}")
                    continue

                data_info[run_id][iter_id] = {
                    'shape': df.shape,
                    'columns': list(df.columns),
                    'hash': self.compute_data_hash(df),
                    'file_path': file_path
                }

        if self._first_entry(data_info) is None:
            return {'error': 'No test data files could be loaded'}

        # Analyze consistency
        print("\n🔬 Analyzing consistency...")

        # Group files by hash (insertion order keeps the output reproducible)
        hash_groups: Dict[str, List[str]] = {}
        for run_id, iterations in data_info.items():
            for iter_id, info in iterations.items():
                hash_groups.setdefault(info['hash'], []).append(f"{run_id}/{iter_id}")

        hash_errors = list(hash_groups.get(self.HASH_ERROR, []))

        results = {
            'total_files': total_files,
            'total_runs': len(files),
            'unique_hashes': len(hash_groups),
            # Files that could not be hashed all share the error sentinel; that
            # must not be mistaken for proof that their contents are identical.
            'all_identical': len(hash_groups) == 1 and not hash_errors,
            'data_info': data_info,
            'hash_groups': hash_groups,
            'failed_files': failed_files,
            'hash_errors': hash_errors
        }

        # Detailed comparison if there are differences
        if not results['all_identical']:
            results['detailed_comparisons'] = self._perform_detailed_comparisons(data_info)

        return results

    def _perform_detailed_comparisons(self, data_info: Dict) -> Dict[str, Any]:
        """
        Perform detailed comparisons between different data files

        The first available file in ``data_info`` is used as the reference and
        every other file is compared against it.

        Args:
            data_info: Information about all loaded data

        Returns:
            Mapping of "run/iter" to the compare_dataframes() result, or
            {'error': message} when no reference data is available
        """
        comparisons = {}

        reference = self._first_entry(data_info)
        if reference is None:
            return {'error': 'Could not load reference data for comparison'}

        first_run, first_iter = reference
        reference_path = data_info[first_run][first_iter]['file_path']
        reference_df = self.load_test_data(reference_path)

        if reference_df is None:
            return {'error': 'Could not load reference data for comparison'}

        print(f"   Using {first_run}/{first_iter} as reference")

        # Compare all other files to reference
        for run_id, iterations in data_info.items():
            for iter_id, info in iterations.items():
                if run_id == first_run and iter_id == first_iter:
                    continue  # Skip reference file

                current_df = self.load_test_data(info['file_path'])
                if current_df is not None:
                    comparisons[f"{run_id}/{iter_id}"] = self.compare_dataframes(reference_df, current_df)

        return comparisons

    def print_summary(self, results: Dict[str, Any]) -> None:
        """
        Print a summary of the analysis results

        Args:
            results: Results from analyze_consistency()
        """
        if 'error' in results:
            print(f"❌ Error: {results['error']}")
            return

        print("\n" + "="*60)
        print("📊 TEST DATA CONSISTENCY ANALYSIS SUMMARY")
        print("="*60)

        print(f"📁 Total files analyzed: {results['total_files']}")
        print(f"🏃 Total runs: {results['total_runs']}")
        print(f"🔑 Unique data versions: {results['unique_hashes']}")

        # .get() keeps result dicts without these optional keys printable.
        failed_files = results.get('failed_files') or []
        if failed_files:
            print(f"⚠️ {len(failed_files)} file(s) could not be loaded and were excluded:")
            for file_id in failed_files:
                print(f"      - {file_id}")

        if results['all_identical']:
            print("\n✅ RESULT: All test data files are IDENTICAL across all runs and iterations")
        else:
            print(f"\n❌ RESULT: Test data files are NOT identical - found {results['unique_hashes']} different versions")

            hash_errors = results.get('hash_errors') or []
            if hash_errors:
                print(f"\n⚠️ {len(hash_errors)} file(s) could not be hashed; their identity cannot be confirmed by hash")

            print("\n📊 Data version groups:")
            for i, (hash_key, file_list) in enumerate(results['hash_groups'].items(), 1):
                print(f"   Version {i} (hash: {hash_key[:8]}...): {len(file_list)} files")
                for file_id in file_list:
                    print(f"      - {file_id}")

            if 'detailed_comparisons' in results:
                print("\n🔍 Detailed differences:")
                for file_id, comparison in results['detailed_comparisons'].items():
                    if not isinstance(comparison, dict):
                        # {'error': message} returned by the comparison step
                        print(f"   {file_id}: {comparison}")
                        continue
                    if not comparison['identical']:
                        print(f"   {file_id}:")
                        for diff in comparison['differences']:
                            print(f"      - {diff}")

        # Show data structure info from the first run/iteration that has data
        first = self._first_entry(results['data_info']) if results['data_info'] else None
        if first is not None:
            first_run, first_iter = first
            first_info = results['data_info'][first_run][first_iter]

            print(f"\n📋 Data structure (from {first_run}/{first_iter}):")
            print(f"   Shape: {first_info['shape']}")
            print(f"   Columns: {len(first_info['columns'])} columns")
            if len(first_info['columns']) <= 10:
                print(f"   Column names: {first_info['columns']}")
            else:
                print(f"   First 10 columns: {first_info['columns'][:10]}")

    def generate_report(self, results: Dict[str, Any], output_file: str = 'test_data_consistency_report.txt') -> None:
        """
        Generate a detailed report file

        Args:
            results: Results from analyze_consistency()
            output_file: Output file path
        """
        # Explicit UTF-8 so run ids / paths with non-ASCII characters are
        # written the same way regardless of the platform default encoding.
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write("TEST DATA CONSISTENCY ANALYSIS REPORT\n")
            f.write("="*50 + "\n\n")

            if 'error' in results:
                f.write(f"Error: {results['error']}\n")
                return

            f.write("Analysis Summary:\n")
            f.write(f"- Total files: {results['total_files']}\n")
            f.write(f"- Total runs: {results['total_runs']}\n")
            f.write(f"- Unique versions: {results['unique_hashes']}\n")
            f.write(f"- All identical: {results['all_identical']}\n\n")

            f.write("File Details:\n")
            for run_id, iterations in results['data_info'].items():
                f.write(f"\n{run_id}:\n")
                for iter_id, info in iterations.items():
                    f.write(f"  {iter_id}:\n")
                    f.write(f"    Shape: {info['shape']}\n")
                    f.write(f"    Hash: {info['hash']}\n")
                    f.write(f"    File: {info['file_path']}\n")

            if 'detailed_comparisons' in results:
                f.write("\nDetailed Comparisons:\n")
                for file_id, comparison in results['detailed_comparisons'].items():
                    f.write(f"\n{file_id}:\n")
                    if not isinstance(comparison, dict):
                        # {'error': message} returned by the comparison step
                        f.write(f"  {comparison}\n")
                        continue
                    f.write(f"  Identical: {comparison['identical']}\n")
                    if comparison['differences']:
                        f.write("  Differences:\n")
                        for diff in comparison['differences']:
                            f.write(f"    - {diff}\n")

        print(f"📄 Detailed report saved to: {output_file}")

# Example usage for SageMaker notebook
def run_analysis_example():
    """Demonstrate the analyzer end to end against the local 'results' folder.

    Builds an analyzer for 'results' (relative to the current working
    directory), runs the consistency analysis, prints the summary, writes the
    report to its default output path, and hands the results dict back.
    """
    consistency_checker = TestDataConsistencyAnalyzer('results')
    outcome = consistency_checker.analyze_consistency()

    # Console summary first, then the on-disk report.
    consistency_checker.print_summary(outcome)
    consistency_checker.generate_report(outcome)

    return outcome

if __name__ == "__main__":
    # Runs only when the file is executed as a script (not on import); the
    # example's return value is kept in the module-level name `results`.
    results = run_analysis_example()
